{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L2PiRHAx0u3c"
      },
      "source": [
        "<a target=\"_blank\" href=\"https://colab.research.google.com/github/qianniucity/llm_notebooks/blob/main/notebooks/Evaluate_RAG_on_Synthetic_Data.ipynb\">\n",
        "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
        "</a>\n",
        "\n",
        "本笔记本通过生成合成数据集并使用 Ragas 框架计算其评估指标来探索基于 RAG 的系统的评估。\n",
        "\n",
        "只需满足以下少量前置条件，本笔记本即可按原样运行：\n",
        "- 可访问 VertexAI 的 GCP 账户\n",
        "- 一个 Pinecone 账户（需设置环境变量 `PINECONE_API_KEY` 和 `PINECONE_ENV`）"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jL9_Jj0IvuB1"
      },
      "outputs": [],
      "source": [
        "# Pin the API generations this notebook is written against:\n",
        "#   - langchain<0.1 keeps the VertexAI / Pinecone / Wikipedia integrations importable from `langchain.*`\n",
        "#   - pinecone-client<3 still provides the module-level pinecone.init / create_index / Index calls\n",
        "#   - ragas<0.1 still provides ragas.llms.LangchainLLM\n",
        "%pip install \"langchain<0.1\" \"pinecone-client<3\" wikipedia datasets \"ragas<0.1\"\n",
        "\n",
        "# Fetch the helper modules that provide TestsetGenerator and RAG (imported in a later cell).\n",
        "! wget -O testset_generator.py https://raw.githubusercontent.com/ahmedbesbes/rag-evaluation-synthetic/main/testset_generator.py\n",
        "! wget -O rag.py https://raw.githubusercontent.com/ahmedbesbes/rag-evaluation-synthetic/main/rag.py"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T8HMj3_7wzVC"
      },
      "outputs": [],
      "source": [
        "import logging\n",
        "import click\n",
        "from rich.logging import RichHandler\n",
        "\n",
        "# Name of the logger this notebook writes its own messages to.\n",
        "LOGGER_NAME = \"custom-rag\"\n",
        "\n",
        "# Send every record through a rich handler; frames from click are suppressed in tracebacks.\n",
        "rich_handler = RichHandler(rich_tracebacks=True, tracebacks_suppress=[click])\n",
        "logging.basicConfig(level=logging.INFO, format=\"%(message)s\", handlers=[rich_handler])\n",
        "logger = logging.getLogger(LOGGER_NAME)\n",
        "\n",
        "# Only let ERROR and above through from these third-party loggers.\n",
        "for noisy_logger in (\"numexpr\", \"httpx\"):\n",
        "    logging.getLogger(noisy_logger).setLevel(logging.ERROR)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WXlIdY6YvqoX"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "\n",
        "import pandas as pd\n",
        "import pinecone\n",
        "import tqdm\n",
        "from langchain.chat_models import ChatVertexAI\n",
        "from langchain.document_loaders import WikipediaLoader\n",
        "from langchain.embeddings import VertexAIEmbeddings\n",
        "from langchain.llms import VertexAI\n",
        "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
        "from langchain.vectorstores import Pinecone\n",
        "\n",
        "from testset_generator import TestsetGenerator\n",
        "from rag import RAG\n",
        "\n",
        "pd.set_option(\"display.max_colwidth\", None)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gS9EDcFawG-6"
      },
      "source": [
        "\n",
        "### Load data from Wikipedia"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wPR-019mwK7S"
      },
      "outputs": [],
      "source": [
        "topic = \"python programming\"\n",
        "\n",
        "# Query Wikipedia for the topic, loading at most one article and capping its content length.\n",
        "wikipedia_load = WikipediaLoader(\n",
        "    query=topic,\n",
        "    load_max_docs=1,\n",
        "    doc_content_chars_max=100000,\n",
        ")\n",
        "docs = wikipedia_load.load()\n",
        "\n",
        "# Fail with an explicit message instead of a bare IndexError when the search returns nothing.\n",
        "if not docs:\n",
        "    raise ValueError(f\"No Wikipedia article found for topic {topic!r}\")\n",
        "doc = docs[0]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5GSygc_Cwe0w"
      },
      "source": [
        "### Index data into Pinecone"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "505e-f88wb0P"
      },
      "outputs": [],
      "source": [
        "# Derive the index name from the topic by swapping spaces for dashes.\n",
        "index_name = \"-\".join(topic.split(\" \"))\n",
        "\n",
        "# Credentials are read from the environment; .get() passes None through when a variable is unset.\n",
        "pinecone.init(api_key=os.environ.get(\"PINECONE_API_KEY\"), environment=os.environ.get(\"PINECONE_ENV\"))\n",
        "\n",
        "# Drop any index that already uses this name before creating it again.\n",
        "index_already_exists = index_name in pinecone.list_indexes()\n",
        "if index_already_exists:\n",
        "    pinecone.delete_index(index_name)\n",
        "\n",
        "# NOTE(review): 768 presumably matches the VertexAIEmbeddings vector size - confirm.\n",
        "pinecone.create_index(index_name, dimension=768)\n",
        "index = pinecone.Index(index_name)\n",
        "\n",
        "logger.info(f\"Index {index_name} created successfully\")\n",
        "logger.info(index.describe_index_stats())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "amwTmYlNw5Yz"
      },
      "outputs": [],
      "source": [
        "# Chunking parameters handed to the splitter as chunk_size / chunk_overlap.\n",
        "CHUNK_SIZE = 512\n",
        "CHUNK_OVERLAP = 128\n",
        "\n",
        "# The only separator offered to the splitter is \". \".\n",
        "splitter_options = {\n",
        "    \"chunk_size\": CHUNK_SIZE,\n",
        "    \"chunk_overlap\": CHUNK_OVERLAP,\n",
        "    \"separators\": [\". \"],\n",
        "}\n",
        "splitter = RecursiveCharacterTextSplitter(**splitter_options)\n",
        "\n",
        "# Chunk the single Wikipedia document loaded earlier.\n",
        "splits = splitter.split_documents([doc])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0598NNJVw81I"
      },
      "outputs": [],
      "source": [
        "embedding_model = VertexAIEmbeddings()\n",
        "\n",
        "# Build the vector store from the chunks, using the Vertex AI embeddings and the index named above.\n",
        "docsearch = Pinecone.from_documents(\n",
        "    documents=splits,\n",
        "    embedding=embedding_model,\n",
        "    index_name=index_name,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fyZVJY4mxEcR"
      },
      "source": [
        "### Create synthetic dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eccXIq7CxHSz"
      },
      "outputs": [],
      "source": [
        "# LLM handed to the TestsetGenerator in the next cell.\n",
        "generator_llm_options = {\n",
        "    \"location\": \"europe-west3\",\n",
        "    \"max_output_tokens\": 256,\n",
        "    \"max_retries\": 20,\n",
        "}\n",
        "generator_llm = VertexAI(**generator_llm_options)\n",
        "\n",
        "# Rebinds `embedding_model` to a fresh VertexAIEmbeddings instance.\n",
        "embedding_model = VertexAIEmbeddings()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6qmypks6zwwJ"
      },
      "outputs": [],
      "source": [
        "# The generator receives the same chunks that were written to the Pinecone index.\n",
        "# NOTE(review): key=\"text\" presumably names the field holding the chunk text - confirm in testset_generator.py.\n",
        "testset_generator = TestsetGenerator(\n",
        "    documents=splits,\n",
        "    index_name=index_name,\n",
        "    key=\"text\",\n",
        "    generator_llm=generator_llm,\n",
        "    embedding_model=embedding_model,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Aa_AnTwgz7SF"
      },
      "outputs": [],
      "source": [
        "# Settings forwarded unchanged to TestsetGenerator.generate.\n",
        "generation_settings = {\"test_size\": 10, \"num_questions_per_context\": 2}\n",
        "synthetic_dataset = testset_generator.generate(**generation_settings)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "64o216ehz94V"
      },
      "outputs": [],
      "source": [
        "# Print up to three randomly sampled question / ground-truth pairs.\n",
        "# Capping the sample size avoids the ValueError DataFrame.sample raises\n",
        "# when asked for more rows than the frame contains.\n",
        "preview_size = min(3, len(synthetic_dataset))\n",
        "\n",
        "for _, row in synthetic_dataset.sample(preview_size).iterrows():\n",
        "    print(f\"question: {row['question']}\")\n",
        "    print(f\"answer: {row['ground_truths']}\")\n",
        "    print(\"\\n====\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A2xtSqCi0BV4"
      },
      "source": [
        "### Generate Answers with the RAG"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HpFAnQGz0Agy"
      },
      "outputs": [],
      "source": [
        "# NOTE(review): `llm` is not referenced by any later cell in this notebook; the RAG helper\n",
        "# below is given the model name string instead - confirm whether this cell is still needed.\n",
        "llm_options = {\n",
        "    \"model_name\": \"text-bison\",\n",
        "    \"max_output_tokens\": 256,\n",
        "    \"temperature\": 0,\n",
        "    \"top_p\": 0.95,\n",
        "    \"top_k\": 40,\n",
        "    \"verbose\": True,\n",
        "}\n",
        "llm = VertexAI(**llm_options)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F7ZTQ-K60Fgz"
      },
      "outputs": [],
      "source": [
        "# NOTE(review): the arguments are positional; judging by the values they look like the index name,\n",
        "# the LLM model name, the embedding model and the text key - confirm against rag.py.\n",
        "rag = RAG(index_name, \"text-bison\", embedding_model, \"text\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wcdyAeDr0N70"
      },
      "outputs": [],
      "source": [
        "rag_answers = []\n",
        "contexts = []\n",
        "\n",
        "# Run every synthetic question through the RAG and keep, per question,\n",
        "# the returned answer and the page_content of each returned source document.\n",
        "questions = synthetic_dataset[\"question\"]\n",
        "for question in tqdm.tqdm(questions, total=len(synthetic_dataset)):\n",
        "    prediction = rag.predict(question)\n",
        "    rag_answers.append(prediction[\"answer\"])\n",
        "    contexts.append([source.page_content for source in prediction[\"source_documents\"]])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9NogHhtC0QU1"
      },
      "outputs": [],
      "source": [
        "# Work on a copy so `synthetic_dataset` itself is left untouched,\n",
        "# then attach the RAG answers and their contexts as two new columns.\n",
        "synthetic_dataset_rag = synthetic_dataset.assign(\n",
        "    answer=rag_answers,\n",
        "    contexts=contexts,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Eu2qcdda0R_O"
      },
      "outputs": [],
      "source": [
        "synthetic_dataset_rag.sample(1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YTfIBRh30YIf"
      },
      "source": [
        "### Evaluate synthetic dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uRzpEbi60T60"
      },
      "outputs": [],
      "source": [
        "from datasets import Dataset\n",
        "from ragas import evaluate\n",
        "from ragas.llms import LangchainLLM\n",
        "\n",
        "# NOTE(review): context_recall and context_relevancy are imported but not used by any later cell.\n",
        "from ragas.metrics import (\n",
        "    answer_correctness,\n",
        "    answer_relevancy,\n",
        "    answer_similarity,\n",
        "    context_precision,\n",
        "    context_recall,\n",
        "    context_relevancy,\n",
        "    faithfulness,\n",
        ")\n",
        "\n",
        "# Convert the pandas frame holding questions, ground truths, answers and contexts into a Dataset.\n",
        "synthetic_ds_rag = Dataset.from_pandas(df=synthetic_dataset_rag)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WUinNERz0hTI"
      },
      "outputs": [],
      "source": [
        "# Dedicated LLM for the Ragas metrics. It gets its own name so the `generator_llm`\n",
        "# configured for test-set generation (different location / retry settings) is not\n",
        "# overwritten if earlier cells are re-run after this one.\n",
        "evaluator_llm = VertexAI(max_output_tokens=256, max_retries=10)\n",
        "\n",
        "# Wrap the LangChain LLM in the adapter type imported from ragas.llms.\n",
        "ragas_vertexai_llm = LangchainLLM(llm=evaluator_llm)\n",
        "vertexai_embeddings = VertexAIEmbeddings()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "INNSTvgJ0jFt"
      },
      "outputs": [],
      "source": [
        "metrics = [answer_relevancy, context_precision, faithfulness, answer_correctness, answer_similarity]\n",
        "\n",
        "# Point every metric at the Vertex AI LLM and, where the metric has an\n",
        "# `embeddings` attribute, at the Vertex AI embeddings as well.\n",
        "for metric in metrics:\n",
        "    metric.llm = ragas_vertexai_llm\n",
        "    if hasattr(metric, \"embeddings\"):\n",
        "        metric.embeddings = vertexai_embeddings\n",
        "\n",
        "# answer_correctness is given the faithfulness and answer_similarity instances configured above.\n",
        "answer_correctness.faithfulness = faithfulness\n",
        "answer_correctness.answer_similarity = answer_similarity"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0AO-eEFT0kUz"
      },
      "outputs": [],
      "source": [
        "# Metrics passed to evaluate(); answer_similarity is not listed on its own here.\n",
        "evaluated_metrics = [answer_relevancy, context_precision, faithfulness, answer_correctness]\n",
        "\n",
        "results_rag = evaluate(synthetic_ds_rag, metrics=evaluated_metrics)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WHEsMQ8f0n5h"
      },
      "outputs": [],
      "source": [
        "print(results_rag)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aOClc6SL0pM4"
      },
      "outputs": [],
      "source": [
        "results_rag.to_pandas()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
